{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fdd7864c-93e6-4eb4-a923-b80d2ae4377d",
   "metadata": {},
   "source": [
    "# JSONFormer\n",
    "\n",
    "[JSONFormer](https://github.com/1rgs/jsonformer) is a library that wraps local HuggingFace pipeline models for structured decoding of a subset of the JSON Schema.\n",
    "\n",
    "It works by filling in the structure tokens and then sampling the content tokens from the model.\n",
    "\n",
    "**Warning - this module is still experimental**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1617e327-d9a2-4ab6-aa9f-30a3167a3393",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "!pip install --upgrade jsonformer > /dev/null"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "66bd89f1-8daa-433d-bb8f-5b0b3ae34b00",
   "metadata": {},
   "source": [
    "### HuggingFace Baseline\n",
    "\n",
    "First, let's establish a qualitative baseline by checking the output of the model without structured decoding."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d4d616ae-4d11-425f-b06c-c706d0386c68",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import logging\n",
    "\n",
    "logging.basicConfig(level=logging.ERROR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1bdc7b60-6ffb-4099-9fa6-13efdfc45b04",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from typing import Optional\n",
    "from langchain.tools import tool\n",
    "import os\n",
    "import json\n",
    "import requests\n",
    "\n",
    "HF_TOKEN = os.environ.get(\"HUGGINGFACE_API_KEY\")\n",
    "\n",
    "\n",
    "@tool\n",
    "def ask_star_coder(query: str, temperature: float = 1.0, max_new_tokens: float = 250):\n",
    "    \"\"\"Query the BigCode StarCoder model about coding questions.\"\"\"\n",
    "    url = \"https://api-inference.huggingface.co/models/bigcode/starcoder\"\n",
    "    headers = {\n",
    "        \"Authorization\": f\"Bearer {HF_TOKEN}\",\n",
    "        \"content-type\": \"application/json\",\n",
    "    }\n",
    "    payload = {\n",
    "        \"inputs\": f\"{query}\\n\\nAnswer:\",\n",
    "        \"temperature\": temperature,\n",
    "        \"max_new_tokens\": int(max_new_tokens),\n",
    "    }\n",
    "    response = requests.post(url, headers=headers, data=json.dumps(payload))\n",
    "    response.raise_for_status()\n",
    "    return json.loads(response.content.decode(\"utf-8\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d5522977-51e8-40eb-9403-8ab70b14908e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "prompt = \"\"\"You must respond using JSON format, with a single action and single action input.\n",
    "You may 'ask_star_coder' for help on coding problems.\n",
    "\n",
    "{arg_schema}\n",
    "\n",
    "EXAMPLES\n",
    "----\n",
    "Human: \"So what's all this about a GIL?\"\n",
    "AI Assistant:{{\n",
    "  \"action\": \"ask_star_coder\",\n",
    "  \"action_input\": {{\"query\": \"What is a GIL?\", \"temperature\": 0.0, \"max_new_tokens\": 100}}\"\n",
    "}}\n",
    "Observation: \"The GIL is python's Global Interpreter Lock\"\n",
    "Human: \"Could you please write a calculator program in LISP?\"\n",
    "AI Assistant:{{\n",
    "  \"action\": \"ask_star_coder\",\n",
    "  \"action_input\": {{\"query\": \"Write a calculator program in LISP\", \"temperature\": 0.0, \"max_new_tokens\": 250}}\n",
    "}}\n",
    "Observation: \"(defun add (x y) (+ x y))\\n(defun sub (x y) (- x y ))\"\n",
    "Human: \"What's the difference between an SVM and an LLM?\"\n",
    "AI Assistant:{{\n",
    "  \"action\": \"ask_star_coder\",\n",
    "  \"action_input\": {{\"query\": \"What's the difference between SGD and an SVM?\", \"temperature\": 1.0, \"max_new_tokens\": 250}}\n",
    "}}\n",
    "Observation: \"SGD stands for stochastic gradient descent, while an SVM is a Support Vector Machine.\"\n",
    "\n",
    "BEGIN! Answer the Human's question as best as you are able.\n",
    "------\n",
    "Human: 'What's the difference between an iterator and an iterable?'\n",
    "AI Assistant:\"\"\".format(\n",
    "    arg_schema=ask_star_coder.args\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9148e4b8-d370-4c05-a873-c121b65057b5",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 'What's the difference between an iterator and an iterable?'\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from transformers import pipeline\n",
    "from langchain.llms import HuggingFacePipeline\n",
    "\n",
    "hf_model = pipeline(\n",
    "    \"text-generation\", model=\"cerebras/Cerebras-GPT-590M\", max_new_tokens=200\n",
    ")\n",
    "\n",
    "original_model = HuggingFacePipeline(pipeline=hf_model)\n",
    "\n",
    "generated = original_model.predict(prompt, stop=[\"Observation:\", \"Human:\"])\n",
    "print(generated)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6e7b9cf-8ce5-4f87-b4bf-100321ad2dd1",
   "metadata": {},
   "source": [
    "***That's not so impressive, is it? It didn't follow the JSON format at all! Let's try with the structured decoder.***"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "96115154-a90a-46cb-9759-573860fc9b79",
   "metadata": {},
   "source": [
    "## JSONFormer LLM Wrapper\n",
    "\n",
    "Let's try that again, now providing a the Action input's JSON Schema to the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "30066ee7-9a92-4ae8-91bf-3262bf3c70c2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "decoder_schema = {\n",
    "    \"title\": \"Decoding Schema\",\n",
    "    \"type\": \"object\",\n",
    "    \"properties\": {\n",
    "        \"action\": {\"type\": \"string\", \"default\": ask_star_coder.name},\n",
    "        \"action_input\": {\n",
    "            \"type\": \"object\",\n",
    "            \"properties\": ask_star_coder.args,\n",
    "        },\n",
    "    },\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0f7447fe-22a9-47db-85b9-7adf0f19307d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from langchain_experimental.llms import JsonFormer\n",
    "\n",
    "json_former = JsonFormer(json_schema=decoder_schema, pipeline=hf_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "d865e049-a5c3-4648-92db-8b912b7474ee",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\"action\": \"ask_star_coder\", \"action_input\": {\"query\": \"What's the difference between an iterator and an iter\", \"temperature\": 0.0, \"max_new_tokens\": 50.0}}\n"
     ]
    }
   ],
   "source": [
    "results = json_former.predict(prompt, stop=[\"Observation:\", \"Human:\"])\n",
    "print(results)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32077d74-0605-4138-9a10-0ce36637040d",
   "metadata": {
    "tags": []
   },
   "source": [
    "**Voila! Free of parsing errors.**"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
